{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T14:55:05.325754Z",
     "start_time": "2020-06-15T14:55:04.214070Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import d2lzh_pytorch as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T14:55:31.922864Z",
     "start_time": "2020-06-15T14:55:31.884106Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size=256\n",
    "train_iter,test_iter, = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T14:57:59.030520Z",
     "start_time": "2020-06-15T14:57:59.025218Z"
    }
   },
   "outputs": [],
   "source": [
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "\n",
    "class LinearNet(nn.Module):\n",
    "    def __init__(self,num_inputs,num_outputs):\n",
    "        super(LinearNet,self).__init__()\n",
    "        self.linear = nn.Linear(num_inputs,num_outputs)\n",
    "    def forward(self,x):\n",
    "        y = self.linear(x.view(x.shape[0],-1))\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T14:58:40.644705Z",
     "start_time": "2020-06-15T14:58:40.641308Z"
    }
   },
   "outputs": [],
   "source": [
    "net = LinearNet(num_inputs,num_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T14:58:45.318684Z",
     "start_time": "2020-06-15T14:58:45.313241Z"
    }
   },
   "outputs": [],
   "source": [
    "# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "class FlattenLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(FlattenLayer, self).__init__()\n",
    "    def forward(self, x): # x shape: (batch, *, *, ...)\n",
    "        return x.view(x.shape[0], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T15:00:01.596502Z",
     "start_time": "2020-06-15T15:00:01.591866Z"
    }
   },
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "net = nn.Sequential(\n",
    "    OrderedDict([\n",
    "        ('flatten',FlattenLayer()),\n",
    "        ('linear',nn.Linear(num_inputs,num_outputs))\n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T15:00:19.826710Z",
     "start_time": "2020-06-15T15:00:19.818574Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "init.normal_(net.linear.weight, mean=0, std=0.01)\n",
    "init.constant_(net.linear.bias, val=0) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T15:00:39.646453Z",
     "start_time": "2020-06-15T15:00:39.643191Z"
    }
   },
   "outputs": [],
   "source": [
    "loss = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T15:00:59.910725Z",
     "start_time": "2020-06-15T15:00:59.907186Z"
    }
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(net.parameters(),lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-06-15T15:01:58.021682Z",
     "start_time": "2020-06-15T15:01:42.035130Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.0031, train acc 0.748, test acc 0.783\n",
      "epoch 2, loss 0.0022, train acc 0.813, test acc 0.802\n",
      "epoch 3, loss 0.0021, train acc 0.826, test acc 0.801\n",
      "epoch 4, loss 0.0020, train acc 0.832, test acc 0.822\n",
      "epoch 5, loss 0.0019, train acc 0.837, test acc 0.821\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 5\n",
    "d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 64-bit ('numpy': virtualenv)",
   "language": "python",
   "name": "python37564bitnumpyvirtualenv325e501f273d447b90cac77d0bfa67d7"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
